""" Mailu marshmallow fields and schema
"""

from copy import deepcopy
from collections import Counter
from datetime import timezone

import inspect
import json
import logging
import yaml

import sqlalchemy

from marshmallow import pre_load, post_load, post_dump, fields, Schema
from marshmallow.utils import ensure_text_type
from marshmallow.exceptions import ValidationError
from marshmallow_sqlalchemy import SQLAlchemyAutoSchemaOpts
from marshmallow_sqlalchemy.fields import RelatedList

from flask_marshmallow import Marshmallow

from cryptography.hazmat.primitives import serialization

from pygments import highlight
from pygments.token import Token
from pygments.lexers import get_lexer_by_name
from pygments.lexers.data import YamlLexer
from pygments.formatters import get_formatter_by_name

from mailu import models, dkim


# flask-marshmallow extension object; BaseSchema below derives from its SQLAlchemyAutoSchema
ma = Marshmallow()


### import logging and schema colorization ###

# registry mapping sqlalchemy model class -> schema class (filled by @mapped)
_model2schema = {}

def get_schema(cls=None):
    """ return the schema class registered for model cls (None if unknown);
        without an argument, return a view of all registered schema classes
    """
    if cls is not None:
        return _model2schema.get(cls)
    return _model2schema.values()

def mapped(cls):
    """ class decorator: register schema cls under cls.Meta.model
        and hand the class back unchanged
    """
    model = cls.Meta.model
    _model2schema[model] = cls
    return cls

class Logger:
    """ helps with counting and colorizing
        imported and exported data

        On creation, sqlalchemy listeners are registered for the models of
        all schemas registered via @mapped. Inserts, updates and deletes are
        counted per table (see changes()) and printed when verbose is set.
    """

    class MyYamlLexer(YamlLexer):
        """ colorize yaml constants and integers """
        def get_tokens(self, text, unfiltered=False):
            for typ, value in super().get_tokens(text, unfiltered):
                if typ is Token.Literal.Scalar.Plain:
                    if value in {'true', 'false', 'null'}:
                        typ = Token.Keyword.Constant
                    elif value == HIDDEN:
                        typ = Token.Error
                    else:
                        try:
                            int(value, 10)
                        except ValueError:
                            try:
                                float(value)
                            except ValueError:
                                pass
                            else:
                                typ = Token.Literal.Number.Float
                        else:
                            typ = Token.Literal.Number.Integer
                yield typ, value

    def __init__(self, want_color=None, can_color=False, debug=False, secrets=False):
        """ want_color, can_color: output is colorized when either is true
            debug: enable sqlalchemy engine logging; format_exception then
                   returns None for exceptions that are not validation errors
            secrets: log secret values instead of hiding them
        """

        self.lexer = 'yaml'
        self.formatter = 'terminal'
        self.strip = False
        self.verbose = 0
        self.quiet = False
        self.secrets = secrets
        self.debug = debug
        self.print = print

        self.color = want_color or can_color

        self._counter = Counter()
        self._schemas = {}

        # log contexts
        self._diff_context = {
            'full': True,
            'secrets': secrets,
        }
        log_context = {
            'secrets': secrets,
        }

        # register listeners
        for schema in get_schema():
            model = schema.Meta.model
            self._schemas[model] = schema(context=log_context)
            sqlalchemy.event.listen(model, 'after_insert', self._listen_insert)
            sqlalchemy.event.listen(model, 'after_update', self._listen_update)
            sqlalchemy.event.listen(model, 'after_delete', self._listen_delete)

        # special listener for dkim_key changes
        # TODO: _listen_dkim can be removed when dkim keys are stored in database
        self._dedupe_dkim = set()
        sqlalchemy.event.listen(models.db.session, 'after_flush', self._listen_dkim)

        # register debug logger for sqlalchemy
        if self.debug:
            logging.basicConfig()
            logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)

    def _log(self, action, target, message=None):
        """ print "<action> <table>: <message>"

            message defaults to target dumped by its registered schema
            (or target itself when no schema is registered); non-strings are repr()'d
        """
        if message is None:
            try:
                message = self._schemas[target.__class__].dump(target)
            except KeyError:
                message = target
        if not isinstance(message, str):
            message = repr(message)
        self.print(f'{action} {target.__table__}: {self.colorize(message)}')

    def _listen_insert(self, mapper, connection, target): # pylint: disable=unused-argument
        """ callback method to track import """
        self._counter.update([('Created', target.__table__.name)])
        if self.verbose:
            self._log('Created', target)

    def _listen_update(self, mapper, connection, target): # pylint: disable=unused-argument
        """ callback method to track import

            Counts target as modified when at least one column really changed.
            In verbose mode every changed attribute is logged as "before -> after".
        """

        changes = {}
        inspection = sqlalchemy.inspect(target)
        for attr in sqlalchemy.orm.class_mapper(target.__class__).column_attrs:
            history = getattr(inspection.attrs, attr.key).history
            if not (history.has_changes() and history.deleted):
                continue
            before = history.deleted[-1]
            after = getattr(target, attr.key)
            # we don't have ordered lists
            if isinstance(before, list):
                before = set(before)
            if isinstance(after, list):
                after = set(after)
            # TODO: this can be removed when comment is not nullable in model
            if attr.key == 'comment' and not before and not after:
                continue
            # only remember changed keys
            if before != after:
                changes[str(attr.key)] = (before, after)
                # without verbose output a single change is enough to count
                # the item as modified - no need to look at other columns
                if not self.verbose:
                    break

        if self.verbose:
            # use schema to log changed attributes
            schema = get_schema(target.__class__)
            only = set(changes.keys()) & set(schema().fields.keys())
            if only:
                for key, value in schema(
                    only=only,
                    context=self._diff_context
                ).dump(target).items():
                    before, after = changes[key]
                    if value == HIDDEN:
                        before = HIDDEN if before else before
                        after = HIDDEN if after else after
                    else:
                        # also hide this
                        after = value
                    self._log('Modified', target, f'{str(target)!r} {key}: {before!r} -> {after!r}')

        if changes:
            self._counter.update([('Modified', target.__table__.name)])

    def _listen_delete(self, mapper, connection, target): # pylint: disable=unused-argument
        """ callback method to track import """
        self._counter.update([('Deleted', target.__table__.name)])
        if self.verbose:
            self._log('Deleted', target)

    # TODO: _listen_dkim can be removed when dkim keys are stored in database
    def _listen_dkim(self, session, flush_context): # pylint: disable=unused-argument
        """ callback method to track import (dkim key changes of Domains) """
        for target in session.identity_map.values():
            # look at Domains originally loaded from db
            if not isinstance(target, models.Domain) or not target._sa_instance_state.load_path:
                continue
            before = target._dkim_key_on_disk
            after = target._dkim_key
            # "de-dupe" messages; this event is fired at every flush
            if before == after or (target, before, after) in self._dedupe_dkim:
                continue
            self._dedupe_dkim.add((target, before, after))
            self._counter.update([('Modified', target.__table__.name)])
            if self.verbose:
                if self.secrets:
                    # a missing key (falsy, e.g. None) is shown as ''
                    before = (before or b'').decode('ascii', 'ignore')
                    after = (after or b'').decode('ascii', 'ignore')
                else:
                    before = HIDDEN if before else ''
                    after = HIDDEN if after else ''
                self._log('Modified', target, f'{str(target)!r} dkim_key: {before!r} -> {after!r}')

    def track_serialize(self, obj, item, backref=None):
        """ callback method to track import

            backref: mapping with keys target, key, before and after; when
                     given, only this backref modification is logged
            item data is printed when verbose >= 2 (secrets hidden unless enabled)
        """
        # called for backref modification?
        if backref is not None:
            self._log(
                'Modified', item, '{target!r} {key}: {before!r} -> {after!r}'.format_map(backref))
            return
        # show input data?
        if self.verbose < 2:
            return
        # hide secrets in data
        if not self.secrets:
            item = self._schemas[obj.opts.model].hide(item)
            if 'hash_password' in item:
                item['password'] = HIDDEN
            if 'fetches' in item:
                for fetch in item['fetches']:
                    fetch['password'] = HIDDEN
        self._log('Handling', obj.opts.model, item)

    def changes(self, *messages, **kwargs):
        """ show changes gathered in counter

            prints messages followed by e.g. "Created: table(n) / Modified: table(n)"
            (or "No changes."); kwargs are passed to the print function
        """
        if self.quiet:
            return
        if self._counter:
            changes = []
            last = None
            for (action, what), count in sorted(self._counter.items()):
                if action != last:
                    if last:
                        changes.append('/')
                    changes.append(f'{action}:')
                    last = action
                changes.append(f'{what}({count})')
        else:
            changes = ['No changes.']
        self.print(*messages, *changes, **kwargs)

    def _format_errors(self, store, path=None):
        """ format (nested) marshmallow error messages

            store: ValidationError.messages - a dict (keys are field names or
                   item indices) of nested dicts or message lists, or a plain
                   list of messages without location
            path: location prefix used while recursing

            Returns a list of (location, message) tuples when called with a
            non-empty path, otherwise a multi-line report string.
        """

        res = []
        if path is None:
            path = []

        if isinstance(store, dict):
            # keys may mix item indices (int) and field names (str):
            # sort ints before strs so both types are never compared
            for key in sorted(store, key=lambda k: (isinstance(k, str), k)):
                location = path + [str(key)]
                value = store[key]
                if isinstance(value, dict):
                    res.extend(self._format_errors(value, location))
                else:
                    if isinstance(value, (str, bytes)):
                        value = [value]
                    res.extend(('.'.join(location), message) for message in value)
        else:
            # messages without field location, e.g. ValidationError('...')
            messages = [store] if isinstance(store, (str, bytes)) else store
            location = '.'.join(path) or '_schema'
            res.extend((location, message) for message in messages)

        if path:
            return res

        maxlen = max((len(loc) for loc, _ in res), default=0)
        lines = [f'     - {loc.ljust(maxlen)} : {msg}' for loc, msg in res]
        count = len(lines)
        errors = f'{count} error{"" if count == 1 else "s"}'
        lines.insert(0, f'[ValidationError] {errors} occurred during input validation')

        return '\n'.join(lines)

    def _is_validation_error(self, exc):
        """ walk traceback to extract invalid field from marshmallow """
        path = []
        trace = exc.__traceback__
        while trace:
            if trace.tb_frame.f_code.co_name == '_serialize':
                if 'attr' in trace.tb_frame.f_locals:
                    path.append(trace.tb_frame.f_locals['attr'])
            elif trace.tb_frame.f_code.co_name == '_init_fields':
                spec = ', '.join(
                    '.'.join(path + [key])
                    for key in trace.tb_frame.f_locals['invalid_fields'])
                return f'Invalid filter: {spec}'
            trace = trace.tb_next
        return None

    def format_exception(self, exc):
        """ format ValidationErrors and other exceptions when not debugging """
        if isinstance(exc, ValidationError):
            return self._format_errors(exc.messages)
        if isinstance(exc, ValueError):
            if msg := self._is_validation_error(exc):
                return msg
        if self.debug:
            return None
        msg = ' '.join(str(exc).split())
        return f'[{exc.__class__.__name__}] {msg}'

    colorscheme = {
        Token:                  ('',        ''),
        Token.Name.Tag:         ('cyan',    'cyan'),
        Token.Literal.Scalar:   ('green',   'green'),
        Token.Literal.String:   ('green',   'green'),
        Token.Name.Constant:    ('green',   'green'), # multiline strings
        Token.Keyword.Constant: ('magenta', 'magenta'),
        Token.Literal.Number:   ('magenta', 'magenta'),
        Token.Error:            ('red',     'red'),
        Token.Name:             ('red',     'red'),
        Token.Operator:         ('red',     'red'),
    }

    def colorize(self, data, lexer=None, formatter=None, color=None, strip=None):
        """ add ANSI color to data

            data is returned unchanged when color is False or self.color is off;
            lexer/formatter/strip default to the instance settings
        """

        if color is False or not self.color:
            return data

        lexer = lexer or self.lexer
        lexer = Logger.MyYamlLexer() if lexer == 'yaml' else get_lexer_by_name(lexer)
        formatter = get_formatter_by_name(formatter or self.formatter, colorscheme=self.colorscheme)
        if strip is None:
            strip = self.strip

        res = highlight(data, lexer, formatter)
        if strip:
            return res.rstrip('\n')
        return res


### marshmallow render modules ###

# hidden attributes
class _Hidden:
    """ falsy placeholder for secret values

        renders as '<hidden>', equals anything whose str() is '<hidden>'
        and is never duplicated by copy/deepcopy
    """
    def __repr__(self):
        return '<hidden>'

    __str__ = __repr__

    def __eq__(self, other):
        return str(other) == '<hidden>'

    def __bool__(self):
        return False

    def __copy__(self):
        return self

    def __deepcopy__(self, memo):
        return self

# YAML: dump hidden values as their plain string form '<hidden>'
yaml.add_representer(_Hidden, lambda dumper, data: dumper.represent_data(str(data)))

# the placeholder instance used wherever a value is hidden
HIDDEN = _Hidden()

# multiline attributes
class _Multiline(str):
    """ str subclass marking text to be dumped as a YAML literal block scalar """

# YAML: dump _Multiline strings in literal block style ("|")
yaml.add_representer(
    _Multiline,
    lambda dumper, data: dumper.represent_scalar('tag:yaml.org,2002:str', data, style='|'),
)

# yaml render module
class RenderYAML:
    """ Marshmallow YAML Render Module (loads/dumps interface)
    """

    class SpacedDumper(yaml.Dumper):
        """ YAML Dumper that puts an empty line between top-level sections
            and indents block sequences nested in mappings
        """

        def write_line_break(self, data=None):
            super().write_line_break(data)
            # one more break at the outermost level gives an empty line
            if len(self.indents) == 1:
                super().write_line_break()

        def increase_indent(self, flow=False, indentless=False):
            # always indent: the requested indentless mode is ignored
            return super().increase_indent(flow, False)

    @staticmethod
    def _augment(kwargs, defaults):
        """ fill kwargs (in place) with defaults for every missing key
        """
        for key, value in defaults.items():
            kwargs.setdefault(key, value)

    _load_defaults = {}
    @classmethod
    def loads(cls, *args, **kwargs):
        """ parse yaml text using yaml.safe_load
        """
        cls._augment(kwargs, cls._load_defaults)
        return yaml.safe_load(*args, **kwargs)

    _dump_defaults = {
        'Dumper': SpacedDumper,
        'default_flow_style': False,
        'allow_unicode': True,
        'sort_keys': False,
    }
    @classmethod
    def dumps(cls, *args, **kwargs):
        """ render data as yaml text (block style, unicode, unsorted keys)
        """
        cls._augment(kwargs, cls._dump_defaults)
        return yaml.dump(*args, **kwargs)

# json encoder
class JSONEncoder(json.JSONEncoder):
    """ JSONEncoder that additionally knows how to encode HIDDEN """
    def default(self, o):
        """ encode _Hidden values as their string form, defer everything else """
        if not isinstance(o, _Hidden):
            return json.JSONEncoder.default(self, o)
        return str(o)

# json render module
class RenderJSON:
    """ Marshmallow JSON Render Module (loads/dumps interface)
    """

    @staticmethod
    def _augment(kwargs, defaults):
        """ fill kwargs (in place) with defaults for every missing key
        """
        for key, value in defaults.items():
            kwargs.setdefault(key, value)

    _load_defaults = {}
    @classmethod
    def loads(cls, *args, **kwargs):
        """ parse json text using json.loads
        """
        cls._augment(kwargs, cls._load_defaults)
        return json.loads(*args, **kwargs)

    _dump_defaults = {
        'separators': (',',':'),
        'cls': JSONEncoder,
    }
    @classmethod
    def dumps(cls, *args, **kwargs):
        """ render data as compact json text (HIDDEN-aware encoder)
        """
        cls._augment(kwargs, cls._dump_defaults)
        return json.dumps(*args, **kwargs)


### marshmallow: custom fields ###

def _rfc3339(datetime):
    """ render datetime as rfc3339 text

        naive values are first converted with astimezone(timezone.utc)
        (i.e. interpreted as local time); a UTC offset is written as 'Z'
    """
    if datetime.tzinfo is None:
        datetime = datetime.astimezone(timezone.utc)
    text = datetime.isoformat()
    utc_suffix = '+00:00'
    return f'{text[:-len(utc_suffix)]}Z' if text.endswith(utc_suffix) else text

# register an 'rfc3339' format with marshmallow's DateTime field and make it the
# default format; parsing reuses marshmallow's 'iso' deserializer
fields.DateTime.SERIALIZATION_FUNCS.update(rfc3339=_rfc3339)
fields.DateTime.DESERIALIZATION_FUNCS.update(
    rfc3339=fields.DateTime.DESERIALIZATION_FUNCS['iso'],
)
fields.DateTime.DEFAULT_FORMAT = 'rfc3339'

class LazyStringField(fields.String):
    """ String field that dumps every falsy value (None, '', 0, False, ...)
        as the empty string
    """

    def _serialize(self, value, attr, obj, **kwargs):
        """ return truthy values unchanged and '' for anything falsy
        """
        if not value:
            return ''
        return value

class CommaSeparatedListField(fields.Raw):
    """ Field accepting either a list or a comma-separated string and
        loading it as a list of strings
    """

    default_error_messages = {
        "invalid": "Not a valid string or list.",
        "invalid_utf8": "Not a valid utf-8 string or list.",
    }

    def _deserialize(self, value, attr, data, **kwargs):
        """ falsy input -> []; list -> items converted to text;
            str/bytes -> stripped, non-empty comma-separated parts
        """

        if not value:
            return []

        if isinstance(value, list):
            try:
                return [ensure_text_type(item) for item in value]
            except UnicodeDecodeError as exc:
                raise self.make_error("invalid_utf8") from exc

        if not isinstance(value, (str, bytes)):
            raise self.make_error("invalid")
        try:
            text = ensure_text_type(value)
        except UnicodeDecodeError as exc:
            raise self.make_error("invalid_utf8") from exc

        parts = (piece.strip() for piece in text.split(','))
        return [part for part in parts if part]


class DkimKeyField(fields.String):
    """ Serialize a dkim key to a multiline string and
        deserialize a dkim key data as string or list of strings
        to a valid dkim key
    """

    default_error_messages = {
        "invalid": "Not a valid string or list.",
        "invalid_utf8": "Not a valid utf-8 string or list.",
    }

    def _serialize(self, value, attr, obj, **kwargs):
        """ serialize dkim key (bytes) as multiline string;
            a falsy key (None or empty) is serialized as ''
        """

        # map empty key and None to the empty string
        if not value:
            return ''

        # return multiline string
        return _Multiline(value.decode('utf-8'))

    def _wrap_key(self, begin, data, end):
        """ generator to wrap key into RFC 7468 format:
            header line, data in 64 character lines, footer line, final ''
        """
        yield begin
        pos = 0
        while pos < len(data):
            yield data[pos:pos+64]
            pos += 64
        yield end
        yield ''

    def _deserialize(self, value, attr, data, **kwargs):
        """ deserialize a string or list of strings to dkim key data
            with verification

            Returns a newly generated key for '-generate-', None for an empty
            or '<hidden>' value, otherwise the key re-wrapped as PEM ascii bytes.

            Raises ValidationError when the input is not text, not valid utf-8,
            or does not contain a loadable unencrypted private key.
        """
        from cryptography.exceptions import UnsupportedAlgorithm

        # convert list to str
        if isinstance(value, list):
            try:
                value = ''.join(ensure_text_type(item) for item in value).strip()
            except UnicodeDecodeError as exc:
                raise self.make_error("invalid_utf8") from exc

        # only text is allowed
        else:
            if not isinstance(value, (str, bytes)):
                raise self.make_error("invalid")
            try:
                value = ensure_text_type(value).strip()
            except UnicodeDecodeError as exc:
                raise self.make_error("invalid_utf8") from exc

        # generate new key?
        if value.lower() == '-generate-':
            return dkim.gen_key()

        # no key or key is <hidden>?
        if not value or str(value) == '<hidden>':
            return None

        # remember part of value for ValidationError
        bad_key = value

        # strip header and footer, clean whitespace and wrap to 64 characters
        try:
            if value.startswith('-----BEGIN '):
                end = value.index('-----', 11) + 5
                header = value[:end]
                value = value[end:]
            else:
                header = '-----BEGIN PRIVATE KEY-----'

            if (pos := value.find('-----END ')) >= 0:
                end = value.index('-----', pos+9) + 5
                footer = value[pos:end]
                value = value[:pos]
            else:
                footer = '-----END PRIVATE KEY-----'
        except ValueError as exc:
            raise ValidationError(f'invalid dkim key {bad_key!r}') from exc

        # remove whitespace from key data
        value = ''.join(value.split())

        # remember part of value for ValidationError
        bad_key = f'{value[:25]}...{value[-10:]}' if len(value) > 40 else value

        try:
            # wrap key according to RFC 7468; encoding happens inside the try
            # so that non-ascii input is reported as a ValidationError
            value = ('\n'.join(self._wrap_key(header, value, footer))).encode('ascii')
            # check key validity; besides ValueError, cryptography raises
            # TypeError for encrypted keys (no password given) and
            # UnsupportedAlgorithm for unsupported key types
            serialization.load_pem_private_key(value, password=None)
        except (UnicodeEncodeError, ValueError, TypeError, UnsupportedAlgorithm) as exc:
            raise ValidationError(f'invalid dkim key {bad_key!r}') from exc

        return value

class PasswordField(fields.Str):
    """ Serialize a hashed password hash by stripping the obsolete {SCHEME}
        Deserialize a plain password or hashed password into a hashed password
    """

    _hashes = {'PBKDF2', 'BLF-CRYPT', 'SHA512-CRYPT', 'SHA256-CRYPT', 'MD5-CRYPT', 'CRYPT'}

    def _strip_scheme(self, value):
        """ remove a leading "{SCHEME}" prefix when SCHEME is one of _hashes """
        if value.startswith('{') and (end := value.find('}', 1)) >= 0:
            if value[1:end] in self._hashes:
                return value[end+1:]
        return value

    def _serialize(self, value, attr, obj, **kwargs):
        """ strip obsolete {password-hash} when serializing """
        # pass None through, like marshmallow's String field does
        if value is None:
            return None
        # strip scheme spec if in database - it's obsolete
        return self._strip_scheme(value)

    def _deserialize(self, value, attr, data, **kwargs):
        """ hashes plain password or checks hashed password
            also strips obsolete {password-hash} when deserializing

            The model class is taken from self.metadata['model']. Raises
            ValidationError for non-string input or an unusable hash.
        """

        # validate type (str/bytes) and decode bytes like a String field
        value = super()._deserialize(value, attr, data, **kwargs)

        # when hashing is requested: use model instance to hash plain password
        if data.get('hash_password'):
            inst = self.metadata['model']()
            inst.set_password(value)
            value = inst.password

        # strip scheme spec when specified - it's obsolete
        value = self._strip_scheme(value)

        # check if algorithm is supported
        inst = self.metadata['model'](password=value)
        try:
            # just check against empty string to see if hash is valid
            inst.check_password('')
        except ValueError as exc:
            # ValueError: hash could not be identified
            raise ValidationError(f'invalid password hash {value!r}') from exc

        return value


### base schema ###

class Storage:
    """ Mixin storing values in self.context['_track'],
        namespaced by the name of the concrete class
    """

    # fallback used only when an instance has no context of its own;
    # note that this dict is shared by all such instances
    context = {}

    def store(self, key, value):
        """ save value under "<ClassName>.<key>" """
        track = self.context.setdefault('_track', {})
        track[f'{type(self).__name__}.{key}'] = value

    def recall(self, key):
        """ fetch the value saved by store(); KeyError when missing """
        namespaced = f'{type(self).__name__}.{key}'
        return self.context['_track'][namespaced]

class BaseOpts(SQLAlchemyAutoSchemaOpts):
    """ Schema options defaulting meta.sqla_session to the Mailu db session
        and meta.sibling to False when the Meta class does not set them
    """
    def __init__(self, meta, ordered=False):
        if not hasattr(meta, 'sibling'):
            meta.sibling = False
        if not hasattr(meta, 'sqla_session'):
            meta.sqla_session = models.db.session
        super().__init__(meta, ordered=ordered)

class BaseSchema(ma.SQLAlchemyAutoSchema, Storage):
    """ Marshmallow base schema with custom exclude logic
        and option to hide sqla defaults
    """

    OPTIONS_CLASS = BaseOpts

    class Meta:
        """ Schema config: defaults for the custom options read by BaseSchema """
        # field names in the order the schema's field lists are rearranged to
        order = []
        # when true, _patch_many does not flush the sqla session
        sibling = False
        # {(flag, ...): {field, ...}}: fields excluded unless a listed context flag is True
        include_by_context = {}
        # {field: [value, ...]}: copied to self._exclude_by_value (presumably used
        # to omit fields holding one of these values on dump - TODO confirm)
        exclude_by_value = {}
        # {(flag, ...): {field, ...}}: fields replaced by HIDDEN in hide() unless a flag is True
        hide_by_context = {}

    def __init__(self, *args, **kwargs):
        """ build the schema and translate the custom Meta options
            (include_by_context, exclude_by_value, hide_by_context, order)
            together with the context flags into excludes and private state

            An explicit context=None / exclude=None is treated like an empty
            context / no excludes.
        """

        # prepare only to auto-include explicitly specified attributes
        only = set(kwargs.get('only') or [])

        # get context; flags are the context keys whose value is exactly True
        context = kwargs.get('context') or {}
        flags = {key for key, value in context.items() if value is True}

        # compile excludes
        exclude = set(kwargs.get('exclude') or [])

        # always exclude
        exclude.update({'created_at', 'updated_at'} - only)

        # add include_by_context: excluded unless one of the needed flags is set
        for need, what in getattr(self.Meta, 'include_by_context', {}).items():
            if not flags & set(need):
                exclude |= what - only

        # update excludes
        kwargs['exclude'] = exclude

        # init SQLAlchemyAutoSchema
        super().__init__(*args, **kwargs)

        # exclude_by_value - copy the value lists: the defaults appended below
        # must not leak into (and accumulate in) the class-level Meta lists
        self._exclude_by_value = {
            key: list(values) for key, values in getattr(self.Meta, 'exclude_by_value', {}).items()
            if key not in only
        }

        # exclude default values
        if not context.get('full'):
            for column in self.opts.model.__table__.columns:
                if column.name not in exclude and column.name not in only:
                    self._exclude_by_value.setdefault(column.name, []).append(
                        None if column.default is None else column.default.arg
                    )

        # hide by context: hidden unless one of the needed flags is set
        self._hide_by_context = set()
        for need, what in getattr(self.Meta, 'hide_by_context', {}).items():
            if not flags & set(need):
                self._hide_by_context |= what - only

        # remember primary keys
        self._primary = str(self.opts.model.__table__.primary_key.columns.values()[0].name)

        # determine attribute order
        if hasattr(self.Meta, 'order'):
            # use user-defined order
            order = self.Meta.order
        else:
            # default order is: primary_key + other keys alphabetically
            order = sorted(self.fields.keys())
            if self._primary in order:
                order.remove(self._primary)
                order.insert(0, self._primary)

        # order fieldlists
        for fieldlist in (self.fields, self.load_fields, self.dump_fields):
            for field in order:
                if field in fieldlist:
                    fieldlist[field] = fieldlist.pop(field)

        # move post_load hook "_add_instance" to the end (after load_instance mixin)
        hooks = self._hooks[('post_load', False)]
        hooks.remove('_add_instance')
        hooks.append('_add_instance')

    def hide(self, data):
        """ return a deep copy of data for logging in which the values of
            all keys in self._hide_by_context are replaced by HIDDEN
        """
        hidden_keys = self._hide_by_context
        result = {}
        for key, value in data.items():
            result[key] = HIDDEN if key in hidden_keys else deepcopy(value)
        return result

    def get_parent(self):
        """ helper to determine parent of current object

            Walks up the call stack to the nearest frame having a local
            named 'ret_d' (presumably marshmallow's Schema._deserialize -
            TODO confirm against the installed marshmallow version) and
            returns (parent, attr_name). The parent is context['config'] when
            that frame belongs to a MailuSchema, otherwise the instance that
            frame's schema looks up for its 'ret_d' data. Returns
            (None, None) when no such frame exists.
        """
        # walk frames directly: inspect.stack() would also read source
        # context lines for every frame, which is needlessly expensive
        frame = inspect.currentframe()
        try:
            while frame is not None:
                loc = frame.f_locals
                if 'ret_d' in loc:
                    if isinstance(loc['self'], MailuSchema):
                        return self.context.get('config'), loc['attr_name']
                    return loc['self'].get_instance(loc['ret_d']), loc['attr_name']
                frame = frame.f_back
        finally:
            # drop the frame reference to avoid a reference cycle
            del frame
        return None, None

    # this is only needed to work around the declared attr "email" primary key in model
    def get_instance(self, data):
        """ lookup item by defined primary key instead of key(s) from model

            Uses Meta.primary_keys when defined and all of them are present
            in data; otherwise falls back to the inherited lookup. Returns
            None for transient schemas.

            Raises ValidationError (located at the first primary key field)
            when the database rejects the lookup value.
        """
        if self.transient:
            return None
        if keys := getattr(self.Meta, 'primary_keys', None):
            filters = {key: data.get(key) for key in keys}
            if None not in filters.values():
                try:
                    return self.session.query(self.opts.model).filter_by(**filters).first()
                except sqlalchemy.exc.StatementError as exc:
                    # the second ValidationError argument is the field name
                    # (error location), not the offending value
                    raise ValidationError(
                        f'Invalid {keys[0]}: {data.get(keys[0])!r}',
                        field_name=keys[0],
                    ) from exc
        return super().get_instance(data)

    @pre_load(pass_many=True)
    def _patch_many(self, items, many, **kwargs): # pylint: disable=unused-argument
        """ - flush sqla session before serializing a section when requested
              (make sure all objects that could be referred to later are created)
            - when in update mode: patch input data before deserialization
              - handle "prune" and "delete" items
              - replace values in keys starting with '-' with default
        """

        # flush sqla session
        if not self.Meta.sibling:
            self.opts.sqla_session.flush()

        # stop early when not updating
        if not self.context.get('update'):
            return items

        # patch "delete", "prune" and "default"
        prune_requested = False

        def patch(count, data):
            """ patch one input item (count is its index in items)

                returns None for a prune item (-pkey: null), a deletion marker
                for "-pkey: id" and otherwise the item with "-key: null"
                entries replaced by the column's default value
            """
            nonlocal prune_requested

            # we only process objects here
            if not isinstance(data, dict):
                raise ValidationError(f'Invalid item. {self.Meta.model.__tablename__.title()} needs to be an object.', f'{data!r}')

            # don't allow __delete__ coming from input
            if '__delete__' in data:
                raise ValidationError('Unknown field.', f'{count}.__delete__')

            # fail when hash_password is specified without password
            if 'hash_password' in data and 'password' not in data:
                raise ValidationError(
                    'Nothing to hash. Field "password" is missing.',
                    field_name = f'{count}.hash_password',
                )

            # handle "prune list" and "delete item" (-pkey: none and -pkey: id)
            # (non-string keys are left for marshmallow to reject)
            for key in data:
                if isinstance(key, str) and key.startswith('-') and key[1:] == self._primary:
                    if data[key] is None:
                        # prune
                        prune_requested = True
                        return None
                    # mark item for deletion
                    return {key[1:]: data[key], '__delete__': count}

            # handle "set to default value" (-key: none)
            def set_default(key, value):
                if not isinstance(key, str) or not key.startswith('-'):
                    return (key, value)
                key = key[1:]
                if key not in self.opts.model.__table__.columns:
                    return (key, None)
                if value is not None:
                    raise ValidationError(
                        'Value must be "null" when resetting to default.',
                        f'{count}.{key}'
                    )
                value = self.opts.model.__table__.columns[key].default
                if value is None:
                    raise ValidationError(
                        'Field has no default value.',
                        f'{count}.{key}'
                    )
                return (key, value.arg)

            return dict(set_default(key, value) for key, value in data.items())

        # convert items to "delete" and filter "prune" item
        patched = [patch(count, item) for count, item in enumerate(items)]
        items = [item for item in patched if item]

        # remember if prune was requested for _prune_items@post_load
        self.store('prune', prune_requested)

        # remember original items to stabilize password-changes in _add_instance@post_load
        self.store('original', items)

        return items

    @pre_load
    def _patch_item(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ - call callback function to track import
            - stabilize import of items with auto-increment primary key
            - delete items
            - delete/prune list attributes
            - add missing required attributes

            data: a single input item (dict); it is patched in place and returned.
            Raises ValidationError when an item or a list entry marked for
            deletion ("-value") cannot be found.
        """

        # callback
        if callback := self.context.get('callback'):
            callback(self, data)

        # stop early when not updating
        if not self.opts.load_instance or not self.context.get('update'):
            return data

        # stabilize import of auto-increment primary keys (not required),
        # by matching import data to existing items and setting primary key.
        # The match is deliberately kept in a local variable: pre_load runs
        # for every item of a "many" load before any instance is created, so
        # storing it in self.instance would make all following items resolve
        # to that same instance. Setting the primary key in data is enough
        # for get_instance() to find the matched item again later.
        instance = None
        if self._primary not in data:
            parent, field = self.get_parent()
            if parent is not None:
                for candidate in getattr(parent, field):
                    existing = self.dump(candidate, many=False)
                    this = existing.pop(self._primary)
                    if data == existing:
                        instance = candidate
                        data[self._primary] = this
                        break

        # try to load instance
        if instance is None:
            instance = self.instance or self.get_instance(data)

        if instance is None:
            if '__delete__' in data:
                # deletion of non-existent item requested
                raise ValidationError(
                    f'Item to delete not found: {data.get(self._primary)!r}.',
                    field_name = f'{data["__delete__"]}.{self._primary}',
                )
            return data

        # delete instance from session when marked
        if '__delete__' in data:
            self.opts.sqla_session.delete(instance)

        # delete item from lists or prune lists
        # currently: domain.alternatives, user.forward_destination,
        # user.manager_of, aliases.destination
        for key, value in data.items():
            if not isinstance(self.fields.get(key), (
                RelatedList, CommaSeparatedListField, fields.Raw)
            ) or not isinstance(value, list):
                continue
            # handle list pruning: drop every "-prune-" marker (a repeated
            # marker must not reach the deletion loop below)
            want_prune = '-prune-' in value
            value = [entry for entry in value if entry != '-prune-']
            # deduplicate new value
            new_value = set(value)
            if not want_prune:
                for old in getattr(instance, key):
                    # using str() is okay for now (see above)
                    new_value.add(str(old))
            # handle item deletion, once per distinct "-entry" marker
            for entry in dict.fromkeys(value):
                if entry.startswith('-'):
                    new_value.discard(entry)
                    try:
                        new_value.remove(entry[1:])
                    except KeyError as exc:
                        raise ValidationError(
                            f'Item to delete not found: {entry[1:]!r}.',
                            field_name=f'?.{key}',
                        ) from exc
            # sort list of new values
            data[key] = sorted(new_value)
            # log backref modification not caught by modify hook
            if isinstance(self.fields[key], RelatedList):
                if callback := self.context.get('callback'):
                    before = {str(v) for v in getattr(instance, key)}
                    after = set(data[key])
                    if before != after:
                        callback(self, instance, {
                            'key': key,
                            'target': str(instance),
                            'before': before,
                            'after': after,
                        })

        # add attributes required for validation from db; input data is
        # keyed by the field's data_key (when set), the model by attribute
        for field_name, field_obj in self.load_fields.items():
            data_key = field_obj.data_key if field_obj.data_key is not None else field_name
            if field_obj.required and data_key not in data:
                data[data_key] = getattr(instance, field_obj.attribute or field_name)

        return data

    @post_load(pass_many=True)
    def _prune_items(self, items, many, **kwargs): # pylint: disable=unused-argument
        """ handle list pruning

            Receives the whole list of item dicts of a "many" load. Existing
            db items of the parent collection that are missing from the
            input are either re-added as {primary: key} (sibling schemas
            without prune request) or deleted from the sqla session
            (non-sibling schemas with prune request).
        """

        # pruning only applies to item lists; stop early when not updating
        if not many or not self.context.get('update'):
            return items

        # get prune flag from _patch_many@pre_load
        want_prune = self.recall('prune')

        # prune: determine if existing items in db need to be added or marked for deletion
        if self.Meta.sibling:
            # parent prunes automatically => add old items unless prune was requested
            add_items, del_items = not want_prune, False
        else:
            # parent does not prune automatically => delete old items when prune was requested
            add_items, del_items = False, bool(want_prune)

        if not (add_items or del_items):
            return items

        parent, field = self.get_parent()
        if parent is None:
            return items

        existing = {item[self._primary] for item in items if self._primary in item}
        for old in getattr(parent, field):
            key = getattr(old, self._primary)
            if key in existing:
                continue
            if add_items:
                items.append({self._primary: key})
            else:
                # old is the db instance itself; self.instance does not
                # refer to this item and must not be used here
                self.opts.sqla_session.delete(old)

        return items

    @post_load
    def _add_instance(self, item, many, **kwargs): # pylint: disable=unused-argument
        """ - add new instances to sqla session
            - undo password change in existing instances when plain password did not change

            item: the loaded model instance; it is returned (possibly with
            its password hash reset to the previous db value).
        """

        # new instance: add it to the session, nothing else to do
        if item not in self.opts.sqla_session:
            self.opts.sqla_session.add(item)
            return item

        # stop early when not updating or item has no password attribute
        if not self.context.get('update') or not hasattr(item, 'password'):
            return item

        # did we hash a new plaintext password? Only when hash_password is
        # true; with "hash_password: false" the given password is a hash.
        original = None
        pkey = getattr(item, self._primary)
        for data in self.recall('original'):
            if data.get('hash_password') and data.get(self._primary) == pkey:
                original = data.get('password')
                break
        if original is None:
            # no plaintext password was hashed for this item
            return item

        # reset hash if plain password matches hash from db
        if attr := getattr(sqlalchemy.inspect(item).attrs, 'password', None):
            history = attr.history
            if history.has_changes() and history.deleted:
                try:
                    # throw-away instance holding the previous (db) hash
                    previous = type(item)(password=history.deleted[-1])
                    if previous.check_password(original):
                        item.password = previous.password
                except ValueError:
                    # hash in db is invalid
                    pass

        return item

    @post_dump
    def _hide_values(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ post-process a dumped item:
            - values of context-hidden keys are replaced by HIDDEN
            - values listed in exclude_by_value are dropped,
              unless the 'full' context flag is set
        """

        hide = self._hide_by_context
        exclude = self._exclude_by_value

        # nothing configured => return the dump unchanged
        if not (exclude or hide):
            return data

        show_all = self.context.get('full')

        def keep(key, value):
            # keys without exclusion rules (or a full dump) are always kept
            if show_all or key not in exclude:
                return True
            return value not in exclude[key]

        return type(data)(
            (key, HIDDEN if key in hide else value)
            for key, value in data.items()
            if keep(key, value)
        )

    # this field is used to mark items for deletion.
    # The pre-load patching stores the item's position in the input list
    # under '__delete__' (also used to build error paths), so the value is
    # an arbitrary int. With marshmallow's default truthy/falsy sets only
    # 0 and 1 are accepted, i.e. marking the third or later item would fail
    # with "Not a valid boolean.". An empty truthy set makes the field accept
    # any value (deserialized as bool(value)); the deletion logic in this
    # schema only checks whether the '__delete__' key is present.
    mark_delete = fields.Boolean(data_key='__delete__', load_only=True, truthy=set())

    # TODO: this can be removed when comment is not nullable in model
    comment = LazyStringField()


### schema definitions ###

@mapped
class DomainSchema(BaseSchema):
    """ Marshmallow schema for Domain model

        Loads into Domain instances (load_instance) and includes
        relationships except users, managers and aliases. Adds dump-only
        DKIM/DNS helper values.
    """
    class Meta:
        """ Schema config """
        model = models.Domain
        load_instance = True
        include_relationships = True
        # users and aliases have their own top-level sections in MailuSchema;
        # managers presumably round-trip via UserSchema.manager_of — confirm
        exclude = ['users', 'managers', 'aliases']

        # context flag(s) -> dump-only fields tied to that context
        # (interpreted by BaseSchema, not shown here — presumably only
        # dumped when 'dns' is set)
        include_by_context = {
            ('dns',): {'dkim_publickey', 'dns_mx', 'dns_spf', 'dns_dkim', 'dns_dmarc'},
        }
        # context flag(s) -> fields whose dumped value is replaced by HIDDEN
        # (see _hide_values; presumably unless 'secrets' is set)
        hide_by_context = {
            ('secrets',): {'dkim_key'},
        }
        # field -> values omitted from dumps unless the 'full' context is set
        # (see _hide_values)
        exclude_by_value = {
            'alternatives': [[]],
            'dkim_key': [None],
            'dkim_publickey': [None],
            'dns_mx': [None],
            'dns_spf': [None],
            'dns_dkim': [None],
            'dns_dmarc': [None],
        }

    # DKIM private key; may be None
    dkim_key = DkimKeyField(allow_none=True)
    # derived values: dumped only, never loaded
    dkim_publickey = fields.String(dump_only=True)
    dns_mx = fields.String(dump_only=True)
    dns_spf = fields.String(dump_only=True)
    dns_dkim = fields.String(dump_only=True)
    dns_dmarc = fields.String(dump_only=True)


@mapped
class TokenSchema(BaseSchema):
    """ Marshmallow schema for Token model

        Used nested inside UserSchema (field "tokens").
    """
    class Meta:
        """ Schema config """
        model = models.Token
        load_instance = True

        # items are a nested list of a parent object; see _prune_items for
        # how sibling schemas are pruned
        sibling = True

    # NOTE(review): password handling is configured with models.User rather
    # than models.Token — confirm this is intended
    password = PasswordField(required=True, metadata={'model': models.User})
    # load-only flag; presumably true when "password" holds a plaintext
    # password that should be hashed (see _add_instance)
    hash_password = fields.Boolean(load_only=True, load_default=False)


@mapped
class FetchSchema(BaseSchema):
    """ Marshmallow schema for Fetch model

        Used nested inside UserSchema (field "fetches").
    """
    class Meta:
        """ Schema config """
        model = models.Fetch
        load_instance = True

        # items are a nested list of a parent object; see _prune_items for
        # how sibling schemas are pruned
        sibling = True
        # status fields tied to the 'full' / 'import' context flags
        # (interpreted by BaseSchema, not shown here)
        include_by_context = {
            ('full', 'import'): {'last_check', 'error'},
        }
        # the fetch password is replaced by HIDDEN in dumps
        # (see _hide_values; presumably unless 'secrets' is set)
        hide_by_context = {
            ('secrets',): {'password'},
        }


@mapped
class UserSchema(BaseSchema):
    """ Marshmallow schema for User model

        Users are identified by "email"; tokens and fetches are nested.
    """
    class Meta:
        """ Schema config """
        model = models.User
        load_instance = True
        include_relationships = True
        # internal/derived columns are not exported; "email" is used instead
        exclude = ['_email', 'domain', 'localpart', 'domain_name', 'quota_bytes_used']

        primary_keys = ['email']
        # field -> values omitted from dumps unless the 'full' context is set
        # (see _hide_values); the dates look like "unset" sentinels — confirm
        exclude_by_value = {
            'forward_destination': [[]],
            'tokens':              [[]],
            'fetches':             [[]],
            'manager_of':          [[]],
            'reply_enddate':       ['2999-12-31'],
            'reply_startdate':     ['1900-01-01'],
        }

    email = fields.String(required=True)
    tokens = fields.Nested(TokenSchema, many=True)
    fetches = fields.Nested(FetchSchema, many=True)

    password = PasswordField(required=True, metadata={'model': models.User})
    # load-only flag; presumably true when "password" holds a plaintext
    # password that should be hashed (see _add_instance)
    hash_password = fields.Boolean(load_only=True, load_default=False)


@mapped
class AliasSchema(BaseSchema):
    """ Marshmallow schema for Alias model

        Aliases are identified by "email"; destinations are handled as a
        comma separated list field.
    """
    class Meta:
        """ Schema config """
        model = models.Alias
        load_instance = True
        # internal/derived columns are not exported; "email" is used instead
        exclude = ['_email', 'domain', 'localpart', 'domain_name']

        primary_keys = ['email']
        # empty destination lists are omitted from dumps unless 'full' is set
        exclude_by_value = {
            'destination': [[]],
        }

    email = fields.String(required=True)
    destination = CommaSeparatedListField()


@mapped
class ConfigSchema(BaseSchema):
    """ Marshmallow schema for Config model

        Registered via @mapped; not nested in MailuSchema (its "config"
        field is commented out there).
    """
    class Meta:
        """ Schema config """
        model = models.Config
        load_instance = True


@mapped
class RelaySchema(BaseSchema):
    """ Marshmallow schema for Relay model

        Uses the auto-generated fields of the model only.
    """
    class Meta:
        """ Schema config """
        model = models.Relay
        load_instance = True


@mapped
class MailuSchema(Schema, Storage):
    """ Marshmallow schema for the complete Mailu configuration

        Loading does not build a new object: the sections are applied to
        the models.MailuConfig object kept in the context under 'config',
        which is then returned.
    """
    class Meta:
        """ Schema config """
        model = models.MailuConfig
        render_module = RenderYAML

        # sections in processing / output order ('config' is not handled yet)
        order = ['domain', 'user', 'alias', 'relay']

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # re-insert the known sections at the end of every field mapping,
        # following Meta.order
        wanted = self.Meta.order
        for mapping in (self.fields, self.load_fields, self.dump_fields):
            present = [name for name in wanted if name in mapping]
            for name in present:
                mapping[name] = mapping.pop(name)

    @pre_load
    def _clear_config(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ ensure the context holds a config object and empty it
            (for the models of all sections) when 'clear' is requested
        """
        context = self.context
        if 'config' not in context:
            context['config'] = models.MailuConfig()
        if not context.get('clear'):
            return data
        section_models = {
            nested_field.nested.opts.model for nested_field in self.fields.values()
        }
        context['config'].clear(models=section_models)
        return data

    @post_load
    def _make_config(self, data, many, **kwargs): # pylint: disable=unused-argument
        """ apply each loaded section (in Meta.order) to the context's
            config object and return that object
        """
        config = self.context['config']
        sections = [name for name in self.Meta.order if name in data]
        for name in sections:
            config.update(data[name], name)
        return config

    domain = fields.Nested(DomainSchema, many=True)
    user = fields.Nested(UserSchema, many=True)
    alias = fields.Nested(AliasSchema, many=True)
    relay = fields.Nested(RelaySchema, many=True)
    # currently disabled section:
    # config = fields.Nested(ConfigSchema, many=True)
